{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f68180fcf50>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "import collections\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "torch.set_printoptions(edgeitems=2)\n",
    "torch.manual_seed(123)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CIFAR-10 class names, ordered so that the list index matches the dataset's\n",
    "# integer label (e.g. 0 -> airplane, 2 -> bird)\n",
    "class_names = [\n",
    "    'airplane', 'automobile', 'bird', 'cat', 'deer',\n",
    "    'dog', 'frog', 'horse', 'ship', 'truck',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar/cifar-10-python.tar.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fd79420a71f449fa81d6f1ac2c2db66f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ../data/cifar/cifar-10-python.tar.gz to ../data/cifar/\n"
     ]
    }
   ],
   "source": [
    "from torchvision import datasets, transforms\n",
    "\n",
    "# Directory the CIFAR-10 archive is downloaded to and extracted in\n",
    "data_path = '../data/cifar/'\n",
    "\n",
    "# Training split: every image is converted to a tensor, then each of its three\n",
    "# channels is normalized with the given mean / std\n",
    "# (presumably the CIFAR-10 training-set statistics -- TODO confirm)\n",
    "cifar10 = datasets.CIFAR10(\n",
    "    root=data_path,\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=(0.4915, 0.4823, 0.4468),\n",
    "                             std=(0.2470, 0.2435, 0.2616)),\n",
    "    ]),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# Held-out split (train=False), preprocessed exactly like the training split:\n",
    "# tensor conversion followed by the same per-channel normalization\n",
    "cifar10_val = datasets.CIFAR10(\n",
    "    root=data_path,\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=(0.4915, 0.4823, 0.4468),\n",
    "                             std=(0.2470, 0.2435, 0.2616)),\n",
    "    ]),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the samples labelled 0 (airplane) or 2 (bird) and remap them to 0 / 1\n",
    "label_map = {0: 0, 2: 1}\n",
    "class_names = ['airplane', 'bird']\n",
    "\n",
    "# The keys of label_map are exactly the labels to keep\n",
    "cifar2 = [\n",
    "    (img, label_map[label])\n",
    "    for img, label in cifar10\n",
    "    if label in label_map\n",
    "]\n",
    "cifar2_val = [\n",
    "    (img, label_map[label])\n",
    "    for img, label in cifar10_val\n",
    "    if label in label_map\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fully connected baseline: 3072 input features (= 3 * 32 * 32, the size of one\n",
    "# image tensor in this notebook) narrowed step by step down to 2 class scores\n",
    "connected_model = nn.Sequential(\n",
    "    nn.Linear(3 * 32 * 32, 1024),\n",
    "    nn.Tanh(),\n",
    "    nn.Linear(1024, 512),\n",
    "    nn.Tanh(),\n",
    "    nn.Linear(512, 128),\n",
    "    nn.Tanh(),\n",
    "    nn.Linear(128, 2),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 3x3 convolution mapping 3 input channels to 16 output channels\n",
    "# (default stride of 1 and no padding)\n",
    "conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)\n",
    "conv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([1, 3, 32, 32]), torch.Size([1, 16, 30, 30]))"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The module expects the batch as the first dimension, so a single image\n",
    "# gets one added with unsqueeze(0)\n",
    "img, _ = cifar2[0]\n",
    "\n",
    "output = conv(img.unsqueeze(0))\n",
    "\n",
    "# With no padding the 3x3 kernel trims one pixel per border: 32x32 -> 30x30\n",
    "img.unsqueeze(0).shape, output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f680435e610>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAXI0lEQVR4nO2dXYyeZZnHf1dbxtZ+SD/od0tbwJQtCiUTshHcuDEYFjXogUYODJuYrQeSaOLBGvdADsnGj3iwMakLETeuH4kaOSC7a9BIiAlpIWxpKdAWSmk7nRZaSluw049rD+bFjHXu/z28M/O+k73/v2QyM+81z3Nf7/08/3ne9/0/13VHZmKM+f/PrH4nYIzpDRa7MY1gsRvTCBa7MY1gsRvTCBa7MY0wZzIbR8RdwPeB2cC/Z+aD6u8HBgZy7ty548ZqFuD73ve+Yuzy5cvF2MjISFfbAVx11VXF2Jw55ambPXu23K/i4sWLxdg777xTjKn5U7nW4pOxZmfNKl9Luj1map+g87106VIxVjovoX48L1y4UIzV8u2WiBj38bfffpuRkZFxg12LPSJmA/8G3AkcBnZExKOZ+Xxpm7lz5zI4ODhurCa8DRs2FGPnz58vxg4ePFiM/elPf5JjLl++vKvYBz7wgWJMnXAAr7/+ejG2e/fuYkydcMuWLZNjquei9ls64d5FCUjN/WuvvVaMDQwMyDHVP8uzZ88WYx/84AeLsYULF8oxh4eHi7F58+YVY5P5R1C6EP3hD38oj9f1aHAbsD8zX87MEeBnwD2T2J8xZhqZjNjXAGP/BR/uPGaMmYFM5j37eK/h/uoNU0RsA7aBft9tjJleJnNlPwysG/P7WuDolX+UmdszczAzB2vvt4wx08dkxL4DuCEiNkbEAPAF4NGpScsYM9V0/TI+My9GxP3AfzNqvT2cmXvUNgMDA6xbt27cmLJbABYtWlSMKWvk1KlTxdjQ0JAcc+nSpcXY+9///mLszJkzxdiuXbvkmCpfZZGpt0i1T5PVJ8bqE/eTJ0/K/b799tvFmJq/a665Ru5XoVwd9TyvvfbaYky5CrUxuz03a6+CS3qRFqzcY4XMfAx4bDL7MMb0Bt9BZ0wjWOzGNILFbkwjWOzGNILFbkwjWOzGNMKkrLf3yoULFzh27Ni4MVUpBrB///5iTHmLyrNVJay1uLov4Ny5c13lA7oSSlV0qf0uWLBAjqmqBpWHrPIBXWWmjre6n2DJkiVyzLVr1xZjb775ZjG2fv36YuzQoUNyTFXB1+29EbW5rVUcjoev7MY0gsVuTCNY7MY0gsVuTCNY7MY0gsVuTCP01HrLzKKVc/r0abmtssFUOaBqtjh//nw5prKkVLnkK6+80tU+AbZs2VKMKYtR2Uq1uVVlmKpBZq1rrbID1dw/99xzxdiRI0fkmLfeemsxtmLFimJMlbGq4wm6lFedm2r+arZwKV91LH1lN6YRLHZjGsFiN6YRLHZjGsFiN6YRLHZjGqGn1tvAwECxKumNN96Q2yr7SFUPnThxohjbuHGjHFOtc6aqwTZt2lSMHT58WI6prDllqyjrTVlDoCuoVFffWqWi6iCrjrfq+lvraKvWbFPPZeXKlXK/itWrVxdj6ph1W20I5UpGVTXpK7sxjWCxG9MIFrsxjWCxG9MIFrsxjWCxG9MIk7LeIuIgcAa4BFzMzEH195lZtBtURRJoG0xZS08//XQxVrPB1IKIym5RNo56HqCrut55551iTFk1yo6pobat2UOqaaKqplPzvmHDBjmmqnpTC24qu081zgQ9R8pee+utt4qx2mKcpSaX07awY4e/z8zXp2A/xphpxC/jjWmEyYo9gf+JiKcjYttUJGSMmR4m+zL+9sw8GhHLgd9GxAuZ+cTYP+j8E9gG+vZJY8z0Mqkre2Ye7Xw/DvwauG2cv9memYOZOaha/xhjppeuxR4R8yNi4bs/A58A
dk9VYsaYqWUyL+NXAL/uVEzNAf4zM/9rSrIyxkw5XYs9M18Gbp6qRGr+s1qcUJVoqoUAa51eX3311WLsmWee6WpM5c8DXH311cWYKilVnnatU6nqgKq84FOnTsn9qvJOdc+Aep6f/OQn5Zif/vSni7GXXnqpGFMdZA8cOCDH3LdvXzGm5lbNQe3eiNLnX8pnt/VmTCNY7MY0gsVuTCNY7MY0gsVuTCNY7MY0Qk+7y86aNau4oN9kFh9U9obqPKsWfQRYtWpVMaa6uaqOtkePHpVjqhJOtVCiKjedzG3Kar+qTBVgZGSkGFOlvOqYXX/99XLMzZs3F2NPPvlkMfb66+XCzdq5qc4FZaWqbrdq4VAoH1N3lzXGWOzGtILFbkwjWOzGNILFbkwjWOzGNEJPrbcLFy4UbSllXwAsXry4qzGV/VPrjlpahBJ0FZ6yP1SHU9BVS6qLrqp6O3funBxTVcWVuphOZL+y0+mc8qmnqhjVvAMcOnSoGHvxxReLsdrCooo1a9YUY+pcKNnQULeFS3Ok5tVXdmMawWI3phEsdmMawWI3phEsdmMawWI3phF6ar3NmTOn2ExQWTygbSfVj15ZNTUb7NixY8VYt5bK0qVL5ZjLly8vxlSDQtUYslZB1e2Ch8rmAl39p+ZBVa4pOw/gj3/8YzG2a9euYkydQ1u2bJFjqueiKviU9auq5aB8vFUFqK/sxjSCxW5MI1jsxjSCxW5MI1jsxjSCxW5MI1jsxjRC1WePiIeBTwHHM/OmzmNLgJ8DG4CDwOczU6/yx6hHWvIWVTkpaC9Yed7K51Q+MOiOo90uwFjz2ZX3qrxgVaZaK+U9fvx4Mabuf1DllKCPi1rIc+PGjcWYuocBYM+ePcXY8PBwMaaO59mzZ+WYqnuvOv9Ud9laKW/pXFDnwUSu7D8C7rrisW8Aj2fmDcDjnd+NMTOYqtgz8wng5BUP3wM80vn5EeAzU5yXMWaK6fY9+4rMHALofC/e4xkR2yJiZ0TsrK2HboyZPqb9A7rM3J6Zg5k5qN6/GGOml27FPhwRqwA638uf8BhjZgTdiv1R4L7Oz/cBv5madIwx08VErLefAh8DlkXEYeBbwIPALyLiS8Ah4HOTTUSV5oEu01SdVVVp7GRQ1pGypJQ1ArpUVT1PRc16W7hwYTGmupzWyjA3bdrUVU5qzNr8qQVAV65cWYwpe23Hjh1yTHVc1Jhqoc5un6eyqKtiz8x7C6GP17Y1xswcfAedMY1gsRvTCBa7MY1gsRvTCBa7MY3Q0+6yitrCjqr6SlkfygZTVWQ1VPWaqliq3TKs7jJUVpeyLmu2ppoHZWXVOr0qG0jF1IKRqkIPtCW6YsWKYkwdM2WHgu7Aqyon1Tlf00Op07DK1Vd2YxrBYjemESx2YxrBYjemESx2YxrBYjemEXpqvV28eLFonUSE3FZZUiMjI8WYqq6qVRYtWbKkq5iyq2ooO0vZiMqarC2aqRpDrlmzphhTjTVB22BHjhwpxlTTSLXYZg1VAansyVozVGVdqvP21Klqj9YipSpQpSNf2Y1pBIvdmEaw2I1pBIvdmEaw2I1pBIvdmEaw2I1phJ767HPnzmXz5s3jxpTXC7q8U/nlyuut+Zyq1FIt5qf88Fq5pCrXVbHTp08XY7WFCdXc7969uxhT9xqAXriwW2rdgg8ePFiMvfDCC8WYKn+dTImw6or8xhtvFGO1exhWrVo17uPq/gVf2Y1pBIvdmEaw2I1pBIvdmEaw2I1pBIvdmEaYyMKODwOfAo5n5k2dxx4A/gk40fmzb2bmY7V9LVq0iDvvvLOrROfPn9/VdsrqOnDggNx2//79xZgqG1VWYG1xRlWuq2Jqv8omBG1PqpLSWrmpWjBSzZGyumqLVCqbUZ0LyrJT8w56AUvVhVh1kFVzB/WS8PGYyJX9R8Bd4zz+vcy8pfNVFboxpr9UxZ6ZTwAne5CLMWYamcx79vsjYldEPBwRi6csI2PMtNCt2H8AXAfc
AgwB3yn9YURsi4idEbFT3dJpjJleuhJ7Zg5n5qXMvAz8ELhN/O32zBzMzMHa/b7GmOmjK7FHxNi78D8LlKsljDEzgolYbz8FPgYsi4jDwLeAj0XELUACB4EvT2SwiChWCF26dEluW1rIDrpfQPCll16SY6oF+7q1AmvWkVpgUL0yUvutLcCoqt6URXbypP7cVll6CmUV1ir4VCWe6hKrFoysWW9qW2WJqqq3mrVWWjBSVQVWxZ6Z947z8EO17YwxMwvfQWdMI1jsxjSCxW5MI1jsxjSCxW5MI1jsxjRCT7vLXrhwgaNHj44bq3mZ3XrpqhR1x44dckyV00033VSMqRJN5cmCzld1MV2+fHkxVutoq7roqhVp1T0BoLvsqv2q+wJqPru630Ddp3DHHXcUY7WOtuo+BbWKqzqetc7HJY9e+fq+shvTCBa7MY1gsRvTCBa7MY1gsRvTCBa7MY3QU+vt/PnzvPzyy+PGaqWfygZT1pHarmZJKetIleSqEs1a2aeah9LcAZw4caIYU2WqoOdPWTm1Mkw1D6rrqtpOWVmgOwYri3bdunXFWG2BSrXoqLIK1X7V8YTyea0W4vSV3ZhGsNiNaQSL3ZhGsNiNaQSL3ZhGsNiNaYSeWm+XLl2qVvOUUNVgylJRXWlrqA6yyjZRi2GojqKgbRyVj5oDVZUF2u5TNtj58+flfpU9qWy71atXF2O14zlv3rxiTNlrpW6tULdoV65cWYypSjtlr9UWAC0tJrlv377iNr6yG9MIFrsxjWCxG9MIFrsxjWCxG9MIFrsxjTCRhR3XAT8GVgKXge2Z+f2IWAL8HNjA6OKOn8/Mqq9WajSoKsxAV4uphoCq4utDH/qQHFNtqywpZWVt3LhRjqlsMmVXLVy4UO5XoWxN1chyMtV0ao5UPrXmj8qCVMdMWWQ1u09Zc6r5qGo4WWusOV0NJy8CX8/MG4G/Bb4SEX8DfAN4PDNvAB7v/G6MmaFUxZ6ZQ5n5TOfnM8BeYA1wD/BI588eAT4zXUkaYybPe3rPHhEbgK3AU8CKzByC0X8IQPn1njGm70xY7BGxAPgl8LXM1PcP/uV22yJiZ0TsnMytq8aYyTEhsUfEVYwK/SeZ+avOw8MRsaoTXwWMu9RJZm7PzMHMHFT3LRtjppeq2GP0I+CHgL2Z+d0xoUeB+zo/3wf8ZurTM8ZMFROpersd+CLwXEQ823nsm8CDwC8i4kvAIeBz05OiMWYqqIo9M58ESgbvx9/LYLNmzSp2B611DVWepPKmle/45ptvyjFvvvnmYmzt2rVd7bfWkVWVAL/yyivFmJof5S+D9q6XLVvWVQxgeHi4GFM+sjqetcUk1aKQ6r4AVY5b63ysjrc6/9S8K38eyvelqPPLd9AZ0wgWuzGNYLEb0wgWuzGNYLEb0wgWuzGN0NPuspcvXy52JK2VuCpLQVkqqtPrwYMH5ZjKBvvwhz9cjCkbTHWIBV3Kq+yhydhVqhRVWVm1DqjHjh0rxtRxUZZdrdOrsnDV/Kmuv7Xn+dprrxVj6nirslpl7ar9Kh35ym5MI1jsxjSCxW5MI1jsxjSCxW5MI1jsxjRCT603KFcQ1RYJVAvvKQtNVSzVqt527dpVjCkLqLZfxZYtW4qxI0eOFGOqa63qEAva1lQ24uHDh+V+n3/++WJMVeINDQ0VYy+88IIcUy3eqGw5ZQUuXbpUjqnOhVdffbUYu+6664qxj370o3LMUgdeZcH6ym5MI1jsxjSCxW5MI1jsxjSCxW5MI1jsxjRCT623ixcvFquLlGUAcOLEiWJMVR2pJo033nijHFPZYAq1MGGtamvr1q3F2KpVq4oxZfcp2xJ040g1pppbgJMnTxZjK1euLMZKi39CvQJt/fr1xZiy19QcHT8+7pIIf0ZZc6qppLIu1fkO5cUmVdWkr+zGNILFbkwjWOzGNILFbkwjWOzGNILFbkwjTGQV13UR8fuI2BsReyLiq53HH4iIIxHxbOfr7ulP1xjT
LRPx2S8CX8/MZyJiIfB0RPy2E/teZn57ooONjIwUvcVad1nlVypfVnWeXbJkiRxz0aJFxZi6L2DNmjXFmFrQEHQ3UpXvU089VYyp5wG6xFV1elX3E4D26Es+cS2fwcHBrsd88cUXizF1fh06dEiOqeZh8+bNctsStcUkS/OnugFPZBXXIWCo8/OZiNgLlM9mY8yM5D29Z4+IDcBW4N3LyP0RsSsiHo6IxVOcmzFmCpmw2CNiAfBL4GuZ+RbwA+A64BZGr/zfKWy3LSJ2RsROdSufMWZ6mZDYI+IqRoX+k8z8FUBmDmfmpcy8DPwQuG28bTNze2YOZuagev9sjJleJvJpfAAPAXsz87tjHh/7Schngd1Tn54xZqqYyKfxtwNfBJ6LiGc7j30TuDcibgESOAh8eVoyNMZMCRP5NP5JYDwv5LH3Otjs2bOL1pKyz2DUtivRbYlmbZFFtfCeyldZMaq0E/SCfqo7qurkWitxVWWj8+bNK8ZqC2NeffXVxZiyiNT81Tr3KttOfWak7FJlE9bGVN151Zg1W7P0lljl4jvojGkEi92YRrDYjWkEi92YRrDYjWkEi92YRuhpd9nMLNofyqIA2LRpUzGmFoVcuHBhMXbmzBk5ZqkTLmhLSm1Xq6BS1VfKllPzs2DBgq7HvOaaa4qxG264Qe5XWZvKelMVX7VjpmzEc+fOFWNqocnVq1fLMVW+av5UrqraEMp2oNqnr+zGNILFbkwjWOzGNILFbkwjWOzGNILFbkwj9NR6g3KjRmX/gLZxVENF1eBRLQgJusJK2VnKHlKVa6ArrPbs2VOMqUUUP/KRj8gx1cKEqmKuVpml4mrMDRs2FGO1hR3VYpPd2n21SkV1zJYvX16MqQUja1Z0qbpNNUL1ld2YRrDYjWkEi92YRrDYjWkEi92YRrDYjWkEi92YRuipzz5nzpxiJ1jlOdbiAwMDXW331ltvyTFVSaTyM9evX1+MKQ8ZtL967NixYkx1ev3d734nx7z++uuLMdVFt1Zuqrx/1S1469atxZiaW9ClquoeB3U8ayxeXF75TN2Pcfr06WJMlW2Du8saYwQWuzGNYLEb0wgWuzGNYLEb0wgWuzGNEKrsb8oHizgBvDrmoWWAXnWwtzgfzUzLB2ZeTv3O59rMHLelbU/F/leDR+zMzMG+JXAFzkcz0/KBmZfTTMtnLH4Zb0wjWOzGNEK/xb69z+NfifPRzLR8YOblNNPy+TN9fc9ujOkd/b6yG2N6RF/EHhF3RcSLEbE/Ir7RjxyuyOdgRDwXEc9GxM4+5fBwRByPiN1jHlsSEb+NiH2d7+Xyqt7k80BEHOnM07MRcXcP81kXEb+PiL0RsScivtp5vC9zJPLp2xzV6PnL+IiYDbwE3AkcBnYA92bm8z1N5C9zOggMZmbf/NGI+DvgLPDjzLyp89i/Aicz88HOP8XFmfnPfcznAeBsZn67Fzlckc8qYFVmPhMRC4Gngc8A/0gf5kjk83n6NEc1+nFlvw3Yn5kvZ+YI8DPgnj7kMaPIzCeAKwvA7wEe6fz8CKMnUz/z6RuZOZSZz3R+PgPsBdbQpzkS+cxY+iH2NcDY1RkO0/9JSuB/IuLpiNjW51zGsiIzh2D05ALKKw70jvsjYlfnZX7P3laMJSI2AFuBp5gBc3RFPjAD5mg8+iH28Vpp9NsSuD0zbwX+AfhK5yWs+Wt+AFwH3AIMAd/pdQIRsQD4JfC1zNSthvqTT9/nqEQ/xH4YWDfm97XA0T7k8Wcy82jn+3Hg14y+1ZgJDHfeG777HlH37ppmMnM4My9l5mXgh/R4niLiKkaF9ZPM/FXn4b7N0Xj59HuOFP0Q+w7ghojYGBEDwBeAR/uQBwARMb/zAQsRMR/4BLBbb9UzHgXu6/x8H/CbPubyrpje5bP0cJ5itLnaQ8DezPzumFBf5qiUTz/nqEpm9vwLuJvRT+QPAP/SjxzG5LIJ+N/O155+5QP8lNGXfRcYffXzJWAp8Diwr/N9
SZ/z+Q/gOWAXoyJb1cN87mD07d4u4NnO1939miORT9/mqPblO+iMaQTfQWdMI1jsxjSCxW5MI1jsxjSCxW5MI1jsxjSCxW5MI1jsxjTC/wG8PfbYQNGylQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Show the first output channel of the single image in the batch in grayscale\n",
    "plt.imshow(output[0, 0].detach(), cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f67fc742990>"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAV2ElEQVR4nO2dXYyd1XWGnxXjAWyM7fEfxlg1IC5aBRWqkVWJqkoVJaEoEnARFC4iKkV1IgUpSLlolF6ES1TlR7moIjkFhVSBJBKJghTUJkKRSG4iDCL8xBRoTGHwMDM2GP8AZjxevZhDNDFnv3v8nZlzJtnvI41m5qzZ315nn++d8/N+a+3ITIwxf/58aNQJGGOGg8VuTCNY7MY0gsVuTCNY7MY0gsVuTCNcMMjgiLgR+BawBviPzLxH/f369etz06ZNfWNnz56Vc33oQ+X/SyoWEZ3GAaxdu7YYu+CCbktXu58q3tUmVWtQi69Zs2ZFjltb+xK1NXjvvfeKsbm5uWJMrXvtfqqczpw502nO+fl5OWcpfurUKd59992+CXcWe0SsAf4d+BgwCTweEQ9n5u9KYzZt2sTnP//5vrF33nlHznfxxRcXYxs2bCjG1Em1fv16OefOnTuLsW3bthVj6sE/efKknPPUqVPFWNd/BEqwoNdWrdHY2Jg87oUXXtgppv6Rvvvuu3LOV155pRibnp4uxtT5V1s/9Q/m2LFjxdiJEyc6xQCOHz/e9/af/exnxTGDvIzfC7yUmb/PzPeAHwA3D3A8Y8wKMojYdwGvLvp9snebMWYVMojY+70v+MBryYjYFxEHIuKAeolqjFlZBhH7JLB70e9XAIfP/aPM3J+ZE5k5UXuPbIxZOQYR++PANRFxZUSMAZ8GHl6etIwxy03nT+Mz80xE3An8NwvW232Z+VxtXMnGUJ8I11CfcL/99tvFWM3GOXr0aDF25ZVXFmPj4+PFmPrkFvSn8eqTaPWJ8SWXXCLnVJ9+q8elZj8qJ0RZUsrqUuOg/Ck1wOzsbDGmPo2/9NJL5ZzqPDp9+nSnmDpvVVw5NgP57Jn5CPDIIMcwxgwHX0FnTCNY7MY0gsVuTCNY7MY0gsVuTCNY7MY0wkDW2/mSmUWftOYFKy9TlS6qUsHa5bvKB1UesqpYqnn7teq/EuvWrSvGalVbauxFF11UjNU8b3VNQddy3VqJsHrMVD5qnDq/QJ9jXb30WnVkKa7Wx8/sxjSCxW5MI1jsxjSCxW5MI1jsxjSCxW5MIwzVejtz5gxHjhzpG6uVfqpySmWDqRLNmo2jLJW33nqrGFO2Uq1TqcpJ3Rdlr3Xt5Aradqo1f1S2k8pXjavZpcqy6tpFt2Yxdi1jVeunxkH53FTnnp/ZjWkEi92YRrDYjWkEi92YRrDYjWkEi92YRhiq9TY/P88bb7zRN1az3lT1lbLlNm7cWIypLrCg9yNTVo2y7GoVVMo66Vr5V7OrVEfWrl1gQVt+ao1mZmaKsVdffbUYAzh06FCnOdU5VKtEVMdV9lrN0lOU8pWbaXaezRjzJ4XFbkwjWOzGNILFbkwjWOzGNILFbkwjDGS9RcTLwAlgHjiTmRPq7+fn56s2UImuNoWqIlONFkHba2NjY8XY2rVri7GaXaVsHlXRpay32tqp+6nsvtpmnGodlNV6+PAHdv7+A889p/cOVWPVluGbN2+Wx1Wox1TZcipWq8gsPWYql+Xw2f8hM/vXrRpjVg1+GW9MIwwq9gR+HhFPRMS+5UjIGLMyDPoy/obMPBwR24FfRMTzmfnY4j/o/RPYB/X3eMaYlWOgZ/bMPNz7PgP8BNjb52/2Z+ZEZk6oD7WMMStLZ7FHxPqI2PD+z8DHgWeXKzFjzPIyyMv4HcBPeh/1XwA8kJn/tSxZGWOWnc5iz8zfA399nmOK/qoqJwXtIytvWm2eN8hmkiqmynFr/qmKK19WeeU1b3+lUPdFlX6q
ktvahofqmoKu103U3n6qc0Hlo87bWil0qXzY3WWNMRa7Ma1gsRvTCBa7MY1gsRvTCBa7MY0w1O6yZ8+eLVonNetNdSpVFo+y3mobE6qxpS65oEspVQy0zaPWSB1XddgFuPTSSzvNqWwe6N51VT2eW7dulXPu2LGjGLviiiuKMXU/a+fJ0aNHO41V1ptaO6jbgf3wM7sxjWCxG9MIFrsxjWCxG9MIFrsxjWCxG9MIQ9/YsWS9qU6k748toaqZVNWRsvNAVx6dOHGiGBukmmnbtm3FmLLIVHdUNQ66dxA6ffq0jKu1V9al6jxbsxG3b99ejO3Zs6cYU+fXsWPH5JwqX1UBqezSmvVWOq7Sgp/ZjWkEi92YRrDYjWkEi92YRrDYjWkEi92YRhiq9aYaTiorq0atYq5Erfmjsj+UVagsvQsu0Euu7JhNmzYVYxs2bCjGuq4PaBtR2WegG0eqDRjVuFqTULUOKqYeM2WfgV5fFVNVlTWLttRE9Fe/+lVxjJ/ZjWkEi92YRrDYjWkEi92YRrDYjWkEi92YRrDYjWmEqs8eEfcBnwRmMvPDvdvGgR8Ce4CXgdsy882lTFjytlU5JGjfUZb1dexKC7p7qppT+eHj4+NyTlWqum7dumJMlanWSnnVZolTU1PF2OTkpDyuKg1VMbW2tXJcdR513eSzViKsHrNdu3YVY12776qxDzzwQHHMUp7ZvwvceM5tXwYezcxrgEd7vxtjVjFVsWfmY8C5l0rdDNzf+/l+4JZlzssYs8x0vVx2R2ZOAWTmVEQU24NExD5gH9QvFTXGrBwr/gFdZu7PzInMnLDYjRkdXcU+HRE7AXrfZ5YvJWPMStBV7A8Dd/R+vgP46fKkY4xZKZZivT0IfATYGhGTwFeBe4AfRcRngVeATy1lsogolgvWShdV6aeylpQVU7M3lPWm3pKoksjaxoRdu8uqOWvlksoGU2Wss7Oz8riqVFWVNNesLkXXc6G2RgplByqrVY2r5VO6L+o8qIo9M28vhD5aG2uMWT34CjpjGsFiN6YRLHZjGsFiN6YRLHZjGmGol7RFRNGyqm3sqFDVa8puUdZaLa4qs1R1Wm1jQmU7KatmbGysGFO51o6rrKOaPaTWQaE6sioLFvT6qjVSXXRPnTol51TrUOoCW8una0WmOmf9zG5MI1jsxjSCxW5MI1jsxjSCxW5MI1jsxjTC0Dd2LFlhg1QdrVRTDGUBqSo9ZQ8puwW0VdO1eWbNYlTW2+WXX16MbdmyRR5XVdMplA1Wa57Z1bZTzShPnz4t51TxN98s92EtbXIKdeutZKeqJpZ+ZjemESx2YxrBYjemESx2YxrBYjemESx2YxrBYjemEYbqs69Zs6ZYwlkrXVTlkqqjpvIda/6p8mxVF1h1X2r+qep4q/xyddzaNQxqjdQ1DBs2bJDHVfmqx1OVlNYeM+WXd+3AW7uOQ5VnqxJr1WG3VpZcWnt5nYY8ojHmzwaL3ZhGsNiNaQSL3ZhGsNiNaQSL3ZhGWMrGjvcBnwRmMvPDvdvuBv4ZeH9nv69k5iO1Y1188cVce+21fWM1G0d1XVXlpsrKOnLkiJxTlVoqW07ZQ2qjxNqcyo5R5ZI1u0rZQ8rKqXUEVnaWsp1OnjxZjNVKXGdmyruHqzVSsRrKRlT5qlhND9u3bz//+eQRF/gucGOf27+Zmdf1vqpCN8aMlqrYM/MxQD8dGWNWPYO8Z78zIp6OiPsiYvOyZWSMWRG6iv3bwNXAdcAU8PXSH0bEvog4EBEH1PtRY8zK0knsmTmdmfOZeRb4DrBX/O3+zJzIzAnV68wYs7J0EntE7Fz0663As8uTjjFmpViK9fYg8BFga0RMAl8FPhIR1wEJvAx8bimTbdy4kU984hN9Y7VNAFUlmRp79OjRYuyFF16Qc7744ovF2PHjx4sxZXXVXt2oyiwVUxVmg1SKqbdeNbtKWW/qcVF2qarQA21tqnVQ+aj1AW39du1CrNYAyla0rGCURwQy
8/Y+N99bG2eMWV34CjpjGsFiN6YRLHZjGsFiN6YRLHZjGsFiN6YRhtpddt26dezd2/9iu1oHT1XeqcaqMsxDhw7JOZX3qvxclavyykHnu3HjxmJM+fc1b/qtt94qxlQZsBoH2odXHWTV+tXKatXOsepaBFVyq3ZiBb1GqjRbxdT6QPm+KH/ez+zGNILFbkwjWOzGNILFbkwjWOzGNILFbkwjDNV6W7t2LTt27OgbUx1OlxIvoWynmqXy+uuvF2PKeuvaURT0Bo1bt24txjZvLncGq20mqeyq6enpYkyV+UL3UtVNmzYVY4OcJ2pOZV0qyw506ax6vNW4mq1ZOq4qK/YzuzGNYLEb0wgWuzGNYLEb0wgWuzGNYLEb0whDtd4UyjIA3eVU2S3KXlNVbaAtKbX5oLK6atVM6riqukrZcjVUxdfs7GwxVrsvqupNdV297LLLirHahodqw02FykfZmjWU9abWr1YFqjbcLOZy3iOMMX+SWOzGNILFbkwjWOzGNILFbkwjWOzGNMJSNnbcDXwPuAw4C+zPzG9FxDjwQ2APC5s73paZsowsM4sWW22TQLW5nrIwlPVWq2ZSDR5VRdcgzQuV9abWSFmTtUo7tbbqvtQaWSqURaaq3rZs2SKPq5o4qnxVI9Ca9abmVGuvKgprlYql48qKS3nEBc4AX8rMvwT+FvhCRPwV8GXg0cy8Bni097sxZpVSFXtmTmXmk72fTwAHgV3AzcD9vT+7H7hlpZI0xgzOeb1nj4g9wPXAb4AdmTkFC/8QgO3LnZwxZvlYstgj4hLgIeCuzNQtSv543L6IOBARB9TlnsaYlWVJYo+ItSwI/fuZ+ePezdMRsbMX3wnM9BubmfszcyIzJwa5ftsYMxhVscfCFff3Agcz8xuLQg8Dd/R+vgP46fKnZ4xZLpZS9XYD8BngmYh4qnfbV4B7gB9FxGeBV4BPrUyKxpjloCr2zPw1UKqn++j5TJaZxXLUmq+oyli7lr+uX79eznn55ZcXY2NjY8WY8qZrJZiqa63yidU1A+qagFpc+bbr1q2Tx1X3VfnlqsR19+7dck7liavrMdSGkbVyU7UO6riqu2ztupPSuTCoz26M+TPAYjemESx2YxrBYjemESx2YxrBYjemEYbeXbZkGdTKTdesWVOMKbtBdeGslX6qskdVhqk6laqyWdAdZNVmf6o7r7ICQa/tIKWfXddBbbJYK6tV5bpqrFqDmnWpypKVRavuZ82iLdl2tt6MMRa7Ma1gsRvTCBa7MY1gsRvTCBa7MY0wVOtNdZetbeyoqoCOHy83zlGxWuccNVbZMcpuUZYdaKtQWTXKXlPVVaCrtpQFpKy1WlxVkqn7Mjk5KedUx1X3Uz1mNYtWbRCq1m98fLwYq1m0pQ0ubb0ZYyx2Y1rBYjemESx2YxrBYjemESx2YxphqNbb/Px80c6qVWapDRFfeumlYuzQoUPF2GuvvSbnnJ2dLcZUdZWyz2pNLpV1opoXKpurtpmksqvUcQexEVUjUGWJqoacoNdPWV3KIqtVZKpmqWptVUPTq666Ss6p7ksJP7Mb0wgWuzGNYLEb0wgWuzGNYLEb0wgWuzGNsJRdXHdHxC8j4mBEPBcRX+zdfndEvBYRT/W+blr5dI0xXVmKz34G+FJmPhkRG4AnIuIXvdg3M/NrS51sbm6O6enpvjFVJggwNTVVjD3//POdxqluraDLalXZqCrXrXm2yj9VJZpdu++C9oJVd9laiataP7XJonpc1P0Ena+KqXxqJcJqrMpXXWtQK3EtHVd10F3KLq5TwFTv5xMRcRDYVRtnjFldnNd79ojYA1wP/KZ3050R8XRE3BcRuom4MWakLFnsEXEJ8BBwV2YeB74NXA1cx8Iz/9cL4/ZFxIGIOHDs2LFlSNkY04UliT0i1rIg9O9n5o8BMnM6M+cz8yzwHWBvv7GZuT8zJzJzonYttTFm5VjKp/EB3AsczMxvLLp956I/uxV4dvnTM8YsF0v5
NP4G4DPAMxHxVO+2rwC3R8R1QAIvA59bkQyNMcvCUj6N/zXQr1bxkfOdbG5urlhWWitdPHz4cDGmxqrSWVWCCdpaUjbOO++8U4ypLqagLRcVU5sLqlxBl2gq60h1uwVt+Sl7Ulldyias5aTWQdlgqssw1M/dLvmo8x3K+ap19RV0xjSCxW5MI1jsxjSCxW5MI1jsxjSCxW5MIwy1u+zc3Byvv/5635jq1graGlEdW9U4VZUF2jpSY1XVlrK5QNuByrZTNuH27dvlnOq+dJ0TdPdZVd2nKiBrj5my5lR3XnU/a9WR6jFV55/aHLRWBVo6T2y9GWMsdmNawWI3phEsdmMawWI3phEsdmMaYajWW2YWrZNaU0Rl86jKLGW31DaTVDaGsluUfaaOCdqCVNVgag22bNki51RNMLs2o6zFVWzr1q3FmKooBL1+yl5Tj1ltE0X1mKpzTDWH7LoZpzqmn9mNaQSL3ZhGsNiNaQSL3ZhGsNiNaQSL3ZhGsNiNaYSh+uxQ9oNrnV6Vr628dNVttFZWqzxdVZ5YK2NVqJJI5dkqD7l2DYMaqzaTrHWXVY+L6pR74YUXFmO180SVo6oOvOrxVCXUANu2bSvG1Pqp6xuUXw7lc1Ode35mN6YRLHZjGsFiN6YRLHZjGsFiN6YRLHZjGiHUx//LPlnELPB/i27aChwZWgJ1nI9mteUDqy+nUefzF5nZ1wscqtg/MHnEgcycGFkC5+B8NKstH1h9Oa22fBbjl/HGNILFbkwjjFrs+0c8/7k4H81qywdWX06rLZ8/MNL37MaY4THqZ3ZjzJAYidgj4saI+J+IeCkivjyKHM7J5+WIeCYinoqIAyPK4b6ImImIZxfdNh4Rv4iIF3vfN484n7sj4rXeOj0VETcNMZ/dEfHLiDgYEc9FxBd7t49kjUQ+I1ujGkN/GR8Ra4AXgI8Bk8DjwO2Z+buhJvLHOb0MTGTmyPzRiPh74CTwvcz8cO+2fwPeyMx7ev8UN2fmv4wwn7uBk5n5tWHkcE4+O4GdmflkRGwAngBuAf6JEayRyOc2RrRGNUbxzL4XeCkzf5+Z7wE/AG4eQR6risx8DHjjnJtvBu7v/Xw/CyfTKPMZGZk5lZlP9n4+ARwEdjGiNRL5rFpGIfZdwKuLfp9k9IuUwM8j4omI2DfiXBazIzOnYOHkAvRG68Phzoh4uvcyf2hvKxYTEXuA64HfsArW6Jx8YBWsUT9GIfZ+rUZGbQnckJl/A/wj8IXeS1jzQb4NXA1cB0wBXx92AhFxCfAQcFdmHh/2/EvIZ+RrVGIUYp8Edi/6/Qrg8Ajy+AOZebj3fQb4CQtvNVYD0733hu+/R5wZZTKZOZ2Z85l5FvgOQ16niFjLgrC+n5k/7t08sjXql8+o10gxCrE/DlwTEVdGxBjwaeDhEeQBQESs733AQkSsBz4OPKtHDY2HgTt6P98B/HSEubwvpve5lSGuUyw0n7sXOJiZ31gUGskalfIZ5RpVycyhfwE3sfCJ/P8C/zqKHBblchXw297Xc6PKB3iQhZd9cyy8+vkssAV4FHix9318xPn8J/AM8DQLIts5xHz+joW3e08DT/W+bhrVGol8RrZGtS9fQWdMI/gKOmMawWI3phEsdmMawWI3phEsdmMawWI3phEsdmMawWI3phH+H8qUSJtz+w/4AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Weights can also be set by hand -- here to a constant averaging (blur) kernel\n",
    "with torch.no_grad():\n",
    "    # Zero the bias so it does not influence the result\n",
    "    conv.bias.zero_()\n",
    "    # Every weight becomes 1/9, so a 3x3 neighborhood is averaged within each input channel\n",
    "    conv.weight.fill_(1.0 / 9.0)\n",
    "\n",
    "output = conv(img.unsqueeze(0))\n",
    "\n",
    "plt.imshow(output[0, 0].detach(), cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 16, 30, 30])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the latest result stored in `output`\n",
    "output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single-output-channel 3x3 convolution; padding=1 keeps the spatial size unchanged\n",
    "conv = nn.Conv2d(in_channels=3, out_channels=1, kernel_size=3, padding=1)\n",
    "\n",
    "with torch.no_grad():\n",
    "    conv.bias.zero_()\n",
    "    # The 3x3 pattern is broadcast over the whole (1, 3, 3, 3) weight tensor.\n",
    "    # Each row is [-1, 0, 1]: right-hand neighbors minus left-hand neighbors,\n",
    "    # so the output responds to horizontal intensity changes (vertical edges).\n",
    "    conv.weight[:] = torch.tensor([\n",
    "        [-1.0, 0.0, 1.0],\n",
    "        [-1.0, 0.0, 1.0],\n",
    "        [-1.0, 0.0, 1.0],\n",
    "    ])\n",
    "\n",
    "output = conv(img.unsqueeze(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.3678, -5.2074,  ..., -1.6409,  1.2073],\n",
       "        [-0.6942, -5.7819,  ..., -2.9649,  2.7501],\n",
       "        ...,\n",
       "        [ 2.5885, -0.6261,  ...,  4.7024,  4.4315],\n",
       "        [ 2.4069, -0.5565,  ...,  3.1238,  3.0363]], grad_fn=<SelectBackward>)"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First output channel of the first image in the batch\n",
    "output[0, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f68144a6810>"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQ8AAAD8CAYAAABpXiE9AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAANx0lEQVR4nO3df8ydZX3H8fdnFPAPmaDdRi1FIGvc1C0RG4a6LM2QBBpDl8gS/EPBQApOMk1cMtQEEpNl6h8uIxBJFSIsBsjA4ONSY2DAcFlgFFIo0DAKycKTNqLgikSnq/vuj+dmOzs9z49e5z4/qu9XcnLuH9e5ry9Xk0+v+xdNVSFJR+vXZl2ApGOT4SGpieEhqYnhIamJ4SGpieEhqclY4ZHkzUnuTfJc933KMu1+kWRP91kYp09J8yHjPOeR5EvAK1X1hSTXAKdU1V+OaPdaVb1xjDolzZlxw+NZYGtVHUyyAXiwqt4+op3hIf2SGTc8/qOqTh5Y/1FVHXHqkuQwsAc4DHyhqu5Z5ng7gB3d6nuaC/sVsGHDhlmXMPfe+ta3zrqEuffYY4/9sKp+o+W361ZrkOQ+4NQRuz53FP2cXlUHkpwF3J9kb1U9P9yoqnYCO7t+fW5+BVdeeeWsS5h711133axLmHtJ/r31t6uGR1V9YIWOv59kw8Bpy0vLHONA9/1CkgeBdwNHhIekY8e4t2oXgEu75UuBbw03SHJKkhO75fXA+4FnxuxX0oyNGx5fAM5P8hxwfrdOki1Jvta1+V1gd5IngAdYuuZheEjHuFVPW1ZSVS8D543Yvhu4olv+F+D3xulH0vzxCVNJTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNegmPJBckeTbJ/iTXjNh/YpI7u/2PJDmjj34lzc7Y4ZHkOOBG4ELgHcCHk7xjqNnlwI+q6reBvwG+OG6/kmarj5nHOcD+qnqhqn4O3AFsH2qzHbi1W74LOC9Jeuhb0oz0ER4bgRcH1he7bSPbVNVh4BDwlh76ljQj63o4xqgZRDW0IckOYEcPNUmasD5mHovApoH104ADy7VJsg54E/DK8IGqamdVbamqLT3UJWmC+giPR4HNSc5McgJwCbAw1GYBuLRbvhi4v6qOmHlIOnaMfdpSVYeTXA18FzgOuKWqnk7yeWB3VS0ANwN/l2Q/SzOOS8btV9Js9XHNg6raBewa2nbtwPJ/An/aR1+S5oNPmEpqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhq0kt4JLkgybNJ9ie5ZsT+y5L8IMme7nNFH/1Kmp114x4gyXHAjcD5wCLwaJKFqnpmqOmdVXX1uP1Jmg99zDzOAfZX1QtV9XPgDmB7D8eVNMf6CI+NwIsD64vdtmEfSvJkkruSbBp1oCQ7kuxOsruHuiRNUB/hkRHbamj928AZVfX7wH3AraMOVFU7q2pLVW3poS5JE9RHeCwCgzOJ04ADgw2q6uWq+lm3+lXgPT30K2mG+giPR4HNSc5McgJwCbAw2CDJhoHVi4B9PfQraYbGvttSVYeTXA18FzgOuKWqnk7yeWB3VS0Af57kIuAw8Apw2bj9SpqtscMDoKp2AbuGtl07sPwZ4DN99CVpPviEqaQmhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaH
pCaGh6QmhoekJoaHpCa9hEeSW5K8lOSpZfYnyfVJ9id5MsnZffQraXb6mnl8Hbhghf0XApu7zw7gKz31K2lGegmPqnoIeGWFJtuB22rJw8DJSTb00bek2ZjWNY+NwIsD64vdtv8nyY4ku5PsnlJdkhqtm1I/GbGtjthQtRPYCZDkiP2S5se0Zh6LwKaB9dOAA1PqW9IETCs8FoCPdnddzgUOVdXBKfUtaQJ6OW1JcjuwFVifZBG4DjgeoKpuAnYB24D9wE+Aj/XRr6TZ6SU8qurDq+wv4BN99CVpPviEqaQmhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCa9hEeSW5K8lOSpZfZvTXIoyZ7uc20f/UqanV7+oWvg68ANwG0rtPleVX2wp/4kzVgvM4+qegh4pY9jSTo2TPOax3uTPJHkO0neOapBkh1JdifZPcW6JDXo67RlNY8Db6uq15JsA+4BNg83qqqdwE6AJDWl2iQ1mMrMo6perarXuuVdwPFJ1k+jb0mTMZXwSHJqknTL53T9vjyNviVNRi+nLUluB7YC65MsAtcBxwNU1U3AxcDHkxwGfgpcUlWelkjHsF7Co6o+vMr+G1i6lSvpl4RPmEpqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGoydngk2ZTkgST7kjyd5JMj2iTJ9Un2J3kyydnj9itptvr4h64PA5+uqseTnAQ8luTeqnpmoM2FwObu8wfAV7pvSceosWceVXWwqh7vln8M7AM2DjXbDtxWSx4GTk6yYdy+Jc1Or9c8kpwBvBt4ZGjXRuDFgfVFjgwYSceQPk5bAEjyRuBu4FNV9erw7hE/qRHH2AHs6KsmSZPTS3gkOZ6l4PhGVX1zRJNFYNPA+mnAgeFGVbUT2Nkd84hwkTQ/+rjbEuBmYF9VfXmZZgvAR7u7LucCh6rq4Lh9S5qdPmYe7wc+AuxNsqfb9lngdICqugnYBWwD9gM/AT7WQ7+SZmjs8Kiqf2b0NY3BNgV8Yty+JM0PnzCV1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1GTs8EiyKckDSfYleTrJJ0e02ZrkUJI93efacfuVNFvrejjGYeDTVfV4kpOAx5LcW1XPDLX7XlV9sIf+JM2BsWceVXWwqh7vln8M7AM2jntcSfMtVdXfwZIzgIeAd1XVqwPbtwJ3A4vAAeAvqurpEb/fAezoVt8FPNVbcf1YD/xw1kUMsJ6VzVs9MH81vb2qTmr5YW/hkeSNwD8Bf1VV3xza9+vAf1fVa0m2AX9bVZtXOd7uqtrSS3E9mbearGdl81YPzF9N49TTy92WJMezNLP4xnBwAFTVq1X1Wre8Czg+yfo++pY0G33cbQlwM7Cvqr68TJtTu3YkOafr9+Vx+5Y0O33cbXk/8BFgb5I93bbPAqcDVNVNwMXAx5McBn4KXFKrny/t7KG2vs1bTdazsnmrB+avpuZ6er1gKulXh0+YSmpieEhqMjfhkeTNSe5N8lz3fcoy7X4x8Jj7wgTquCDJs0n2J7lmxP4Tk9zZ7X+ke7ZlotZQ02VJfjAwLldMsJZbkryUZOQzOFlyfVfrk0nOnlQtR1HT1F6PWOPrGlMdo4m9QlJVc/EBvgRc0y1fA3xxmXavTbCG44DngbOAE4AngHcMtfkz4KZu+RLgzgmPy1pq
ugy4YUp/Tn8EnA08tcz+bcB3gADnAo/MQU1bgX+Y0vhsAM7ulk8C/m3En9dUx2iNNR31GM3NzAPYDtzaLd8K/MkMajgH2F9VL1TVz4E7uroGDdZ5F3De67ehZ1jT1FTVQ8ArKzTZDtxWSx4GTk6yYcY1TU2t7XWNqY7RGms6avMUHr9VVQdh6T8W+M1l2r0hye4kDyfpO2A2Ai8OrC9y5CD/b5uqOgwcAt7Scx1HWxPAh7op8F1JNk2wntWstd5pe2+SJ5J8J8k7p9Fhd0r7buCRoV0zG6MVaoKjHKM+nvNYsyT3AaeO2PW5ozjM6VV1IMlZwP1J9lbV8/1UyKgZxPC97LW06dNa+vs2cHtV/SzJVSzNjP54gjWtZNrjsxaPA2+r/3s94h5gxdcjxtW9rnE38KkaeM/r9d0jfjLxMVqlpqMeo6nOPKrqA1X1rhGfbwHff33q1n2/tMwxDnTfLwAPspSifVkEBv/WPo2lF/lGtkmyDngTk50yr1pTVb1cVT/rVr8KvGeC9axmLWM4VTXl1yNWe12DGYzRJF4hmafTlgXg0m75UuBbww2SnJLkxG55PUtPtw7/f0PG8SiwOcmZSU5g6YLo8B2dwTovBu6v7orThKxa09D58kUsndPOygLw0e6OwrnAoddPR2dlmq9HdP2s+LoGUx6jtdTUNEbTuAK9xivCbwH+EXiu+35zt30L8LVu+X3AXpbuOOwFLp9AHdtYuhr9PPC5btvngYu65TcAfw/sB/4VOGsKY7NaTX8NPN2NywPA70ywltuBg8B/sfQ36OXAVcBV3f4AN3a17gW2TGF8Vqvp6oHxeRh43wRr+UOWTkGeBPZ0n22zHKM11nTUY+Tj6ZKazNNpi6RjiOEhqYnhIamJ4SGpieEhqYnhIamJ4SGpyf8AR2/geByTsNMAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Visualize the hand-set 3x3 kernel (first output channel, first input channel)\n",
    "plt.imshow(conv.weight[0, 0].detach().numpy(), cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([1, 3, 32, 32]), torch.Size([1, 3, 16, 16]))"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 2x2 max pooling halves the height and width: 32x32 -> 16x16\n",
    "pool = nn.MaxPool2d(kernel_size=2)\n",
    "output = pool(img.unsqueeze(0))\n",
    "\n",
    "img.unsqueeze(0).shape, output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small CNN mapping a batch of 3x32x32 images to 2 class scores per image\n",
    "model = nn.Sequential(\n",
    "    nn.Conv2d(3, 16, kernel_size=3, padding=1),   # 3x32x32  -> 16x32x32\n",
    "    nn.Tanh(),\n",
    "    nn.MaxPool2d(2),                              # 16x32x32 -> 16x16x16\n",
    "    nn.Conv2d(16, 8, kernel_size=3, padding=1),   # 16x16x16 -> 8x16x16\n",
    "    nn.Tanh(),\n",
    "    nn.MaxPool2d(2),                              # 8x16x16  -> 8x8x8\n",
    "    # nn.Linear works on the last dimension only, so the 8x8x8 feature maps\n",
    "    # must be flattened to 512 features first; without this step the forward\n",
    "    # pass fails with a shape mismatch\n",
    "    nn.Flatten(),                                 # 8x8x8    -> 512\n",
    "    nn.Linear(8 * 8 * 8, 32),\n",
    "    nn.Tanh(),\n",
    "    nn.Linear(32, 2),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The same network written as an nn.Module subclass\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Two conv / tanh / max-pool stages followed by two fully connected layers.\n",
    "\n",
    "    forward() reshapes the pooled features to (-1, 8*8*8), which matches\n",
    "    3x32x32 inputs, and returns 2 scores per sample.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Stage 1: 3 -> 16 channels, then 2x2 max pooling\n",
    "        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)\n",
    "        self.act1 = nn.Tanh()\n",
    "        self.pool1 = nn.MaxPool2d(2)\n",
    "        # Stage 2: 16 -> 8 channels, then 2x2 max pooling\n",
    "        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)\n",
    "        self.act2 = nn.Tanh()\n",
    "        self.pool2 = nn.MaxPool2d(2)\n",
    "        # Classifier head: 512 -> 32 -> 2\n",
    "        self.fc1 = nn.Linear(8 * 8 * 8, 32)\n",
    "        self.act3 = nn.Tanh()\n",
    "        self.fc2 = nn.Linear(32, 2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Convolutional stages, one operation per line\n",
    "        features = self.conv1(x)\n",
    "        features = self.act1(features)\n",
    "        features = self.pool1(features)\n",
    "        features = self.conv2(features)\n",
    "        features = self.act2(features)\n",
    "        features = self.pool2(features)\n",
    "        # Collapse everything but the batch dimension into 512 features per row\n",
    "        flat = features.view(-1, 8 * 8 * 8)\n",
    "        hidden = self.act3(self.fc1(flat))\n",
    "        return self.fc2(hidden)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(18090, [432, 16, 1152, 8, 16384, 32, 64, 2])"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Once instantiated, the parameters of the submodules assigned in __init__\n",
    "# can be reached through model.parameters()\n",
    "model = Net()\n",
    "\n",
    "# Element count of every parameter tensor, shown together with the total\n",
    "numel_list = [param.numel() for param in model.parameters()]\n",
    "sum(numel_list), numel_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every nn module has a functional counterpart, so the parameter-free steps\n",
    "# (pooling and tanh) can be called as plain functions instead of being stored\n",
    "# as submodules\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Functional-style variant: only layers that own parameters are attributes.\n",
    "\n",
    "    forward() reshapes the pooled features to (-1, 8*8*8), which matches\n",
    "    3x32x32 inputs, and returns 2 scores per sample.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)\n",
    "        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)\n",
    "        self.fc1 = nn.Linear(8 * 8 * 8, 32)\n",
    "        self.fc2 = nn.Linear(32, 2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # torch.tanh is used instead of F.tanh, which is a deprecated alias\n",
    "        out = F.max_pool2d(torch.tanh(self.conv1(x)), 2)  # 2 is the pooling kernel size\n",
    "        out = F.max_pool2d(torch.tanh(self.conv2(out)), 2)\n",
    "        # Collapse everything but the batch dimension into 512 features per row\n",
    "        out = out.view(-1, 8 * 8 * 8)\n",
    "        out = torch.tanh(self.fc1(out))\n",
    "        out = self.fc2(out)\n",
    "\n",
    "        return out"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
